import csv
import numpy as np
from loader import load_animal_data

def main():
    """Run the interactive animal-guessing game on 'animals.csv'.

    Starting from the distribution returned by ``load_animal_data`` (normalised
    here), repeatedly ask the question that minimises the expected remaining
    entropy and update the distribution with the user's answer. The game ends
    when one animal has probability 1. It also ends early in two cases:

    * no question can still narrow the candidates down, because asking an
      uninformative question again would loop forever;
    * the answers have ruled out every known animal.
    """
    animal_pmf, question_data = load_animal_data('animals.csv')
    normalize_pmf(animal_pmf)
    # Smallest entropy reduction (in bits) that counts as progress; absorbs
    # floating-point noise from re-normalising an unchanged distribution.
    min_gain = 1e-9

    while not is_certain(animal_pmf):
        print("--------------------------")
        print_pmf(animal_pmf)
        uncertainty = compute_uncertainty(animal_pmf)
        print(f"=> Uncertainty: {uncertainty:.2f}")

        best_question = chose_best_question(animal_pmf, question_data)
        if best_question is None:
            break  # there are no questions to ask at all
        expected = compute_expected_uncertainty(animal_pmf, question_data[best_question])
        # If even the best question cannot lower the entropy, every remaining
        # candidate gives the same answer to it: asking would not change the
        # distribution and the loop would never terminate.
        if expected >= uncertainty - min_gain:
            break

        answer = _ask_yes_no(f"{best_question} ")
        animal_pmf = compute_posterior(animal_pmf, question_data[best_question], answer)
        if not animal_pmf:
            break  # the answers contradict every known animal

    print("====================================\n")
    if not animal_pmf:
        print("No animal I know matches your answers.")
        return
    print(f"=> Uncertainty: {compute_uncertainty(animal_pmf):.2f}")
    print(f"Your animal is a {max(animal_pmf, key=animal_pmf.get)}")


def _ask_yes_no(prompt):
    """Prompt until ``parse_yes_no`` accepts the reply; return it as a bool."""
    while True:
        try:
            return parse_yes_no(input(prompt))
        except ValueError as err:
            # Re-prompt on a typo instead of crashing the whole game.
            print(err)

def chose_best_question(animal_pmf, question_data):
    """Pick the question whose answer is expected to leave the least entropy.

    Args:
        animal_pmf: mapping animal -> probability for the current belief.
        question_data: mapping question -> per-animal answer table.

    Returns:
        The key of ``question_data`` with the smallest expected uncertainty.
        On ties, the earliest such key in iteration order wins. Returns None
        when ``question_data`` is empty.
    """
    def expected_entropy(candidate):
        return compute_expected_uncertainty(animal_pmf, question_data[candidate])

    # min() keeps the first minimal item and only replaces it on a strict "<",
    # which matches a manual running-minimum loop.
    return min(question_data, key=expected_entropy, default=None)

def compute_expected_uncertainty(animal_pmf, question_data):
    """Return the expected entropy left after asking a single question.

    The result is the posterior entropy for each possible answer, weighted by
    that answer's probability under ``animal_pmf``:

        P(yes) * H(pmf | yes) + (1 - P(yes)) * H(pmf | no)

    Args:
        animal_pmf: mapping animal -> probability for the current belief.
        question_data: mapping animal -> answer for this one question.
    """
    p_yes = compute_prop_yes(animal_pmf, question_data)

    posterior_if_yes = compute_posterior(animal_pmf, question_data, True)
    posterior_if_no = compute_posterior(animal_pmf, question_data, False)

    entropy_if_yes = compute_uncertainty(posterior_if_yes)
    entropy_if_no = compute_uncertainty(posterior_if_no)

    # P(no) is taken as the complement of P(yes), so animal_pmf is
    # expected to be normalised.
    return p_yes * entropy_if_yes + (1 - p_yes) * entropy_if_no

def compute_uncertainty(animal_pmf):
    """Return the Shannon entropy, in bits, of a probability mass function.

    Args:
        animal_pmf: mapping animal -> probability.

    Returns:
        sum(p * log2(1/p)) over all entries. Returns 0 for an empty mapping.

    Entries with probability 0 contribute nothing: p * log2(1/p) tends to 0
    as p tends to 0. They are skipped explicitly because evaluating ``1/p``
    at p == 0 would raise ZeroDivisionError.
    """
    entropy = 0
    for p_i in animal_pmf.values():
        if p_i > 0:
            entropy += np.log2(1 / p_i) * p_i
    return entropy

def compute_posterior(animal_pmf, question_data, answer):
    """Condition ``animal_pmf`` on an observed answer to one question.

    Keeps only the animals whose recorded answer equals ``answer``, then
    rescales their probabilities to sum to 1. The input mapping is not
    modified.

    Args:
        animal_pmf: mapping animal -> probability.
        question_data: mapping animal -> recorded answer for the question.
        answer: the answer that was observed (compared with ``==``).

    Returns:
        A new dict with the consistent animals. It is empty when no animal
        matches.
    """
    consistent = {
        name: weight
        for name, weight in animal_pmf.items()
        if question_data[name] == answer
    }
    normalize_pmf(consistent)
    return consistent

def compute_prop_yes(animal_pmf, question_data):
    """Return the total probability of the animals that answer "yes".

    An animal counts as "yes" when its entry in ``question_data`` is truthy.
    Returns 0 for an empty ``animal_pmf``.
    """
    return sum(
        weight
        for name, weight in animal_pmf.items()
        if question_data[name]
    )

def is_certain(animal_pmf):
    """Return True when some animal holds all the probability mass.

    The test is an exact ``== 1`` comparison. Returns False for an empty
    mapping.
    """
    return any(weight == 1 for weight in animal_pmf.values())

def parse_yes_no(response):
    """Convert a user's yes/no reply into a bool.

    Surrounding whitespace and letter case are ignored, so replies like
    "Yes" or " no " are accepted. The short forms "y" and "n" are accepted
    as well.

    Args:
        response: the raw text typed by the user.

    Returns:
        True for a "yes" reply, False for a "no" reply.

    Raises:
        ValueError: if the reply is neither form.
    """
    normalized = response.strip().lower()
    if normalized in ('yes', 'y'):
        return True
    if normalized in ('no', 'n'):
        return False
    raise ValueError("Response must be 'yes' or 'no'")
    
def normalize_pmf(animal_pmf):
    """Rescale the values of ``animal_pmf`` in place so they sum to 1.

    Returns None; the mapping itself is updated. An empty mapping is left
    untouched.
    """
    mass = sum(animal_pmf.values())
    # Only existing keys are reassigned, so iterating while writing is safe.
    for name, weight in animal_pmf.items():
        animal_pmf[name] = weight / mass

def print_pmf(animal_pmf):
    """Print one "<animal> <probability>" line per entry, with 2 decimals."""
    for name, weight in animal_pmf.items():
        print(f"{name} {weight:.2f}")

# Start the interactive game only when this file is run as a script, not when
# it is imported.
if __name__ == '__main__':
    main()